import os
import time


def get_ckpt_save_path(train_mode, model_name, data_name, JR_lamda):
    """Create the checkpoint directory for a run and return a timestamped path in it.

    The directory is ``checkpoint/<run_name>/``. ``<run_name>`` is
    ``<model_name>_<data_name>_<suffix>``, where the suffix depends on
    ``train_mode``. For ``"grad_align"`` it is
    ``<model_name>_<data_name>_<JR_lamda>_grad_align``. The directory is
    created if it does not already exist.

    Args:
        train_mode: One of ``"adv"``, ``"fat"``, ``"std"``, ``"small_scale"``,
            ``"grad_align"``, ``"ex_fgsm"``, ``"fgsm"``.
        model_name: Model identifier, inserted into the directory name.
        data_name: Dataset identifier, inserted into the directory name.
        JR_lamda: Inserted into the directory name only for ``"grad_align"``.

    Returns:
        ``"checkpoint/<run_name>/"`` followed by the current local time
        formatted as ``%m%d_%H:%M:%S``.

    Raises:
        ValueError: If ``train_mode`` is not one of the supported modes.
    """
    # Directory-name suffix per training mode. Note that "adv" runs are saved
    # under the "adv_2" suffix.
    suffixes = {
        "adv": "adv_2",
        "fat": "fat",
        "std": "std",
        "small_scale": "small_scale",
        "ex_fgsm": "ex_fgsm",
        "fgsm": "fgsm",
    }
    if train_mode == "grad_align":
        run_name = "{}_{}_{}_grad_align".format(model_name, data_name, JR_lamda)
    elif train_mode in suffixes:
        run_name = "{}_{}_{}".format(model_name, data_name, suffixes[train_mode])
    else:
        raise ValueError("{} is not a train mode".format(train_mode))

    prefix = 'checkpoint/' + run_name + '/'
    # exist_ok avoids a race between checking for the directory and creating it.
    os.makedirs(prefix, exist_ok=True)
    # Format only the timestamp. The prefix is built from caller-supplied strings,
    # and a '%' in it must not be interpreted as a strftime directive.
    # NOTE(review): ':' in the timestamp is not valid in Windows file names.
    return prefix + time.strftime('%m%d_%H:%M:%S')


def get_ckpt_load_path(train_mode, model_name, data_name, JR_lamda):
    """Return the path of the ``best_adv`` checkpoint for a previous run.

    The path is ``checkpoint/<run_name>/best_adv`` (joined with ``os.path.join``).
    ``<run_name>`` is ``<model_name>_<data_name>_<train_mode>``, except for
    ``"grad_align"``, which uses ``<model_name>_<data_name>_<JR_lamda>_grad_align``.

    Args:
        train_mode: One of ``"adv_10"``, ``"adv_2"``, ``"fat"``, ``"std"``,
            ``"small_scale"``, ``"grad_align"``, ``"ex_fgsm"``, ``"fgsm"``.
        model_name: Model identifier, inserted into the directory name.
        data_name: Dataset identifier, inserted into the directory name.
        JR_lamda: Inserted into the directory name only for ``"grad_align"``.

    Returns:
        The checkpoint path as a string.

    Raises:
        ValueError: If ``train_mode`` is not one of the supported modes.
    """
    # For these modes the directory suffix is the mode name itself.
    plain_modes = ("adv_10", "adv_2", "fat", "std", "small_scale", "ex_fgsm", "fgsm")
    if train_mode == "grad_align":
        run_name = "{}_{}_{}_grad_align".format(model_name, data_name, JR_lamda)
    elif train_mode in plain_modes:
        run_name = "{}_{}_{}".format(model_name, data_name, train_mode)
    else:
        raise ValueError("{} is not a train mode".format(train_mode))
    # Loads the 'best_adv' checkpoint. The original code had 'best_acc' commented
    # out as an alternative.
    return os.path.join("checkpoint", run_name, 'best_adv')


def get_detail_load_path(train_mode, model_name, data_name, JR_lamda):
    """Return the ``detail`` path inside a run's checkpoint directory.

    The path is ``checkpoint/<run_name>/detail`` (joined with ``os.path.join``).
    ``<run_name>`` is ``<model_name>_<data_name>_<suffix>``, except for
    ``"grad_align"``, which uses ``<model_name>_<data_name>_<JR_lamda>_grad_align``.

    Args:
        train_mode: One of ``"adv"``, ``"fat"``, ``"std"``, ``"small_scale"``,
            ``"grad_align"``, ``"ex_fgsm"``, ``"fgsm"``.
        model_name: Model identifier, inserted into the directory name.
        data_name: Dataset identifier, inserted into the directory name.
        JR_lamda: Inserted into the directory name only for ``"grad_align"``.

    Returns:
        The detail path as a string.

    Raises:
        ValueError: If ``train_mode`` is not one of the supported modes.
    """
    # Directory-name suffix per training mode. Note that "adv" maps to the
    # "adv_10" directory here.
    suffixes = {
        "adv": "adv_10",
        "fat": "fat",
        "std": "std",
        "small_scale": "small_scale",
        "ex_fgsm": "ex_fgsm",
        "fgsm": "fgsm",
    }
    if train_mode == "grad_align":
        run_name = "{}_{}_{}_grad_align".format(model_name, data_name, JR_lamda)
    elif train_mode in suffixes:
        run_name = "{}_{}_{}".format(model_name, data_name, suffixes[train_mode])
    else:
        raise ValueError("{} is not a train mode".format(train_mode))
    return os.path.join("checkpoint", run_name, 'detail')
